#!/usr/bin/env python3

#==============================================================================
# Functionality
#==============================================================================
import pdb
import sys
import os
import re

# utility funcs, classes, etc go here.

def asserting(cond):
    """Assert *cond*, dropping into the debugger first when it is falsy.

    On failure pdb is entered so the surrounding state can be inspected;
    once the debugger session continues, the ordinary ``assert`` fires.
    """
    failed = not cond
    if failed:
        pdb.set_trace()
    assert cond

def has_stdin():
    """Return True when standard input is not attached to a terminal."""
    interactive = sys.stdin.isatty()
    return not interactive

def reg(pat, flags=0):
    """Compile *pat* as a verbose-mode regular expression.

    ``re.VERBOSE`` is always switched on, so unescaped whitespace and
    ``#`` comments inside the pattern are ignored; any extra *flags* are
    OR-ed on top of it.
    """
    combined = flags | re.VERBOSE
    return re.compile(pat, combined)

def matches(s, pat, *, case='smart', verbose=False):
    """Report whether string *s* satisfies the search pattern *pat*.

    *pat* is a Python regular expression compiled with ``re.VERBOSE`` (so
    unescaped whitespace in it is ignored) and searched for anywhere in *s*.
    A leading ``!`` negates the pattern: the result is True only when the
    remainder of the pattern is NOT found.

    Args:
        s: The string to test.
        pat: The regular expression, optionally prefixed with ``!``.
        case: ``'on'`` for case-sensitive matching, ``'off'`` for
            case-insensitive matching, or ``'smart'`` to ignore case only
            when the pattern contains no uppercase character.
        verbose: Print which case-handling branch was taken.

    Returns:
        bool: True when *s* passes the (possibly negated) pattern.

    Raises:
        ValueError: If *case* is not one of the three accepted values.
    """
    # A leading '!' flips the expectation: the string must not match.
    expect_match = True
    if pat[:1] == '!':
        expect_match = False
        pat = pat[1:]

    flags = 0
    if case == 'off':
        if verbose:
            print('case==off')
        flags = re.IGNORECASE
    elif case == 'on':
        if verbose:
            print('case==on')
    elif case == 'smart':
        if verbose:
            print('case==smart')
        # NOTE: every character counts, so an escape such as \S or \W also
        # makes the pattern case-sensitive.
        if not any(ch.isupper() for ch in pat):
            if verbose:
                print('smartcase-ignore')
            flags = re.IGNORECASE
    else:
        raise ValueError('unknown case specification %s' % case)

    # Use search() on the bare pattern rather than match() on '.*' + pat:
    # the prefixed form binds only to the first branch of a top-level
    # alternation ('a|b' became '(?:.*)a|b') and pushes inline flags away
    # from the start of the expression.
    found = re.compile(pat, re.VERBOSE | flags).search(s)
    return expect_match == (found is not None)

def narrow(strings, searches, **kws):
    """Keep only the strings that satisfy every search pattern.

    *strings* may be a single string, in which case the result is a bool
    telling whether that one string survives all patterns. Otherwise each
    pattern in turn filters the collection through ``matches`` (extra
    keyword arguments are forwarded to it) and the surviving strings are
    returned. *searches* may be a single pattern, a sequence of patterns,
    or None (no filtering; *strings* is returned untouched).
    """
    if isinstance(strings, str):
        # Single-string form: wrap it, filter, and report survival.
        return len(narrow([strings], searches, **kws)) > 0

    if searches is None:
        patterns = []
    elif isinstance(searches, str):
        patterns = [searches]
    else:
        patterns = searches

    remaining = strings
    for pattern in patterns:
        remaining = [
            candidate for candidate in remaining
            if matches(candidate, pattern, **kws)
        ]
    return remaining



#==============================================================================
# Cmdline
#==============================================================================
import argparse

# Command-line interface. RawTextHelpFormatter keeps the line breaks of the
# description below exactly as written.
parser = argparse.ArgumentParser(
    formatter_class=argparse.RawTextHelpFormatter,
    description="""
Export the tensors of a TensorFlow checkpoint to JSON files.

Files are written beneath a local directory derived from the checkpoint path:
shapes.jsonl and types.jsonl list every variable, and each exported tensor
gets its own <tensor name>/data.json.
""")

parser.add_argument('checkpoint',
    help="The path to the Tensorflow checkpoint. E.g. gs://mlpublic-euw4/runs/bigrun97/dec28/run6_evos0_imagenet_dlrmul_0_4/model.ckpt-9658250" )

parser.add_argument('-v', '--verbose',
    action="store_true",
    help="verbose output" )

parser.add_argument('-d', '--dry-run',
    action="store_true",
    help="Just list the tensors that would be exported; don't actually download any of them." )

# NOTE: nargs='*' is greedy, so on the command line the positional checkpoint
# must come before -e (or be separated from the patterns by another option).
parser.add_argument('-e', '--pattern', nargs='*',
    help="only export layers that match this (python-style) regexp." )

# Parsed command-line namespace; stays None until main() fills it in.
args = None

#==============================================================================
# Main
#==============================================================================

import json
import tensorflow as tf
import os
import tqdm
from natsort import natsorted


def maketree(path):
    """Create directory *path* and any missing parents, best-effort.

    An already-existing directory is not an error. Other filesystem
    failures (``OSError``) are deliberately ignored, as before; the
    caller's subsequent attempt to open a file inside the directory will
    surface them. Anything that is not an ``OSError`` (for example
    ``KeyboardInterrupt`` or a ``TypeError`` from a bad argument) now
    propagates instead of being silently swallowed.
    """
    try:
        os.makedirs(path, exist_ok=True)
    except OSError:
        pass

def numpy2json(arr):
    """Serialise *arr* to JSON text via its ``tolist()`` representation."""
    nested = arr.tolist()
    return json.dumps(nested)

def tensor2json(ckpt, outdir, name):
    """Write one checkpoint tensor to ``<outdir>/<name>/data.json``.

    Args:
        ckpt: Checkpoint reader; only ``ckpt.get_tensor(name)`` is used.
        outdir: Root directory of the export tree.
        name: Tensor name; each ``/`` in it becomes a directory level.

    A non-empty ``data.json`` left by an earlier run is kept as is, which
    lets an interrupted export be resumed. To keep that check trustworthy
    the JSON is written to a temporary sibling file and renamed into place
    only once it is complete, so a crash or Ctrl-C can never leave a
    truncated ``data.json`` that a later run would mistake for finished.
    """
    tensor_dir = os.path.join(outdir, name.replace('/', os.path.sep))
    maketree(tensor_dir)
    outpath = os.path.join(tensor_dir, 'data.json')
    if os.path.exists(outpath) and os.path.getsize(outpath) > 0:
        return
    # Fetch and serialise before touching the filesystem so a failed read
    # leaves nothing behind.
    payload = numpy2json(ckpt.get_tensor(name))
    tmppath = outpath + '.tmp'
    with open(tmppath, 'w') as f:
        print(payload, file=f)
    # os.replace overwrites any existing (empty) data.json in one step.
    os.replace(tmppath, outpath)

from contextlib import contextmanager

@contextmanager
def nullcontext():
    """Context manager that does nothing and binds None to the ``as`` target."""
    yield None


def maybe_open(path, *argz, **kws):
    """Open *path* unless this is a dry run.

    In normal operation this is a plain ``open(path, *argz, **kws)``. When
    ``args.dry_run`` is set, no file is touched and a do-nothing context
    manager is returned instead, so ``with maybe_open(...) as f`` binds
    ``f`` to None.
    """
    if not args.dry_run:
        return open(path, *argz, **kws)
    return nullcontext()

def ckpt2json(checkpoint_path):
    """Export the tensors of the checkpoint at *checkpoint_path* as JSON.

    The output directory is the checkpoint path itself with ``://`` turned
    into ``/`` and every ``/`` turned into the local path separator. Inside
    it, ``shapes.jsonl`` and ``types.jsonl`` receive one ``[name, shape]``
    respectively ``[name, dtype name]`` JSON array per line for every
    variable, in natural sort order. Each variable whose name passes
    ``args.pattern`` is then logged and written with ``tensor2json``.

    When ``args.dry_run`` is set the selected tensors are only logged:
    no directory is created and no file is written.

    Reads the module-level ``args`` (``dry_run``, ``pattern``) and sets
    ``args.log`` to the progress bar's ``write`` method.
    """
    ckpt = tf.train.load_checkpoint(checkpoint_path)
    outdir = checkpoint_path.replace('://', '/')
    outdir = outdir.replace('/', os.path.sep)
    # A dry run must leave the filesystem untouched, so the output
    # directory is only created for a real export.
    if not args.dry_run:
        maketree(outdir)

    shapes = ckpt.get_variable_to_shape_map()
    types = {
        variable: dtype.name
        for variable, dtype in ckpt.get_variable_to_dtype_map().items()
    }
    names = natsorted(shapes)

    # maybe_open yields None instead of a file object during a dry run.
    with maybe_open(os.path.join(outdir, 'shapes.jsonl'), 'w') as f:
        if f is not None:
            for name in names:
                print(json.dumps([name, shapes[name]]), file=f)
    with maybe_open(os.path.join(outdir, 'types.jsonl'), 'w') as f:
        if f is not None:
            for name in names:
                print(json.dumps([name, types[name]]), file=f)

    with tqdm.tqdm(narrow(names, args.pattern)) as pbar:
        # Log through the progress bar so messages do not break its display.
        args.log = pbar.write
        for name in pbar:
            args.log('{shape!r:>24} {name!r} -> {path}'.format(
                name=name,
                path=os.path.join(outdir, name, 'data.json'),
                shape=shapes[name]))
            if not args.dry_run:
                tensor2json(ckpt, outdir, name)

def run():
    """Run the export described by the already-parsed module-level ``args``.

    Echoes the parsed arguments first when ``--verbose`` was given.
    """
    verbose = args.verbose
    if verbose:
        print(args)
    ckpt2json(args.checkpoint)

def main():
    """Script entry point: parse the command line if needed, then run.

    The module-level ``args`` is only parsed when it has not been set
    already, so a caller may pre-populate it before calling ``main()``.

    Returns:
        Whatever ``run()`` returns, or None when the output pipe was closed
        by the reader.

    Only ``BrokenPipeError`` is treated as a quiet exit. On Python 3
    ``IOError`` is an alias of ``OSError``, so catching it here used to
    hide every file or permission error behind a silent, successful exit;
    those errors now propagate.
    """
    global args
    try:
        if not args:
            args = parser.parse_args()
        return run()
    except BrokenPipeError:
        # http://stackoverflow.com/questions/15793886/how-to-avoid-a-broken-pipe-error-when-printing-a-large-amount-of-formatted-data
        # Closing the streams here keeps the interpreter from reporting
        # the same broken pipe again while flushing at shutdown.
        for stream in (sys.stdout, sys.stderr):
            try:
                stream.close()
            except OSError:
                pass

# Run the exporter only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()

